import os

import matplotlib.pyplot as plt
import numpy as np
import tqdm

from src.imports.import_clustering_sets import import_clustering_sets
from src.imports.import_profile import (
    import_coordinates, import_radius,
)
from src.imports.import_profile import import_profile


# Colour palette per base distance: entry k is the colour of cluster k.
# The orderings differ between distances -- presumably so that clusters that
# look alike get the same colour across plots; NOTE(review): confirm intent.
COLORS = {
    'norm_hamming': ['orange', 'red', 'purple', 'blue', 'green'],
    'geom_hamming': ['purple', 'blue', 'orange', 'green', 'red'],
    'rms_hamming': ['orange', 'red', 'blue', 'green', 'purple'],

    'jaccard': ['blue', 'orange', 'red', 'purple', 'green'],
    'geom_jaccard': ['purple', 'red', 'blue', 'orange', 'green'],
    'rms_jaccard': ['blue', 'orange', 'red', 'green', 'purple'],
}


def candidate_colors(test_c_sets, num_clusters, num_candidates, palette):
    """Return one colour per candidate, taken from the cluster it belongs to.

    Args:
        test_c_sets: Clustering of a single test instance, keyed by the
            cluster index as a string (``'0'``, ``'1'``, ...); each value is
            an iterable of the candidate indices in that cluster.
        num_clusters: Number of clusters to read (keys ``'0'`` up to
            ``str(num_clusters - 1)``).
        num_candidates: Candidates are the indices ``0 .. num_candidates - 1``.
        palette: Sequence of colours; ``palette[k]`` is used for cluster ``k``.

    Returns:
        A list of length ``num_candidates`` whose i-th entry is the colour of
        candidate i. A candidate listed in several clusters gets the colour of
        the lowest-numbered one.

    Raises:
        ValueError: If some candidate does not appear in any cluster. Skipping
            it would leave the colour list shorter than the coordinate array,
            which ``plt.scatter`` either rejects with a generic size-mismatch
            error or (for lengths 0/1) silently misrenders.
    """
    cluster_of = {}
    for k in range(num_clusters):
        for candidate in test_c_sets[str(k)]:
            # setdefault keeps the first (lowest) cluster index seen, matching
            # a first-match scan over clusters 0, 1, ...
            cluster_of.setdefault(candidate, k)

    missing = [i for i in range(num_candidates) if i not in cluster_of]
    if missing:
        raise ValueError(
            f"candidates {missing} are not assigned to any of the "
            f"{num_clusters} clusters"
        )
    return [palette[cluster_of[i]] for i in range(num_candidates)]


if __name__ == "__main__":
    # For every base distance: load the clustering produced by `method`, then
    # for each of the `num_tests` sampled 2-D Euclidean instances draw the
    # candidates (coloured by cluster, sized by radius) over the voters
    # (gray) and save the figure as a PNG.

    num_tests = 100
    num_voters = 1000
    num_candidates = 100
    lower_radius = 0.1
    upper_radius = 0.33

    num_clusters = 5

    method = 'hierarchical_minavg'

    base_distances = ['norm_hamming', 'geom_hamming', 'rms_hamming',
                      'jaccard', 'geom_jaccard', 'rms_jaccard']

    coords_dir = 'data/sampled/euclidean_2d/coordinates'
    out_dir = f'images/euclidean_2d/k{num_clusters}'
    # plt.savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    for base_distance in base_distances:
        palette = COLORS[base_distance]
        # Running sum of the per-profile fraction of ones in P; assumes P is
        # a num_voters x num_candidates 0/1 matrix -- TODO confirm.
        sat = 0

        # IMPORT CLUSTERS
        c_sets_path = f"output/sampled/euclidean_2d/{method}_{base_distance}_c_sets_{num_clusters}_{num_voters}_{num_candidates}_{lower_radius}_{upper_radius}.txt"
        c_sets = import_clustering_sets(c_sets_path)

        for t in tqdm.tqdm(range(num_tests)):
            # Filename suffix shared by every file of instance `t`.
            suffix = f'{num_voters}_{num_candidates}_{t}_{lower_radius}_{upper_radius}'

            P = import_profile(f'data/sampled/euclidean_2d/profiles/profile_{suffix}')
            sat += np.sum(P) / num_voters / num_candidates

            # IMPORT COORDINATES
            c_coordinates = import_coordinates(f'{coords_dir}/c_points_{suffix}.csv')
            c_radius = import_radius(f'{coords_dir}/c_radius_{suffix}.csv')
            v_coordinates = import_coordinates(f'{coords_dir}/v_points_{suffix}.csv')

            # PLOT
            colors = candidate_colors(c_sets[str(t)], num_clusters,
                                      num_candidates, palette)

            # Candidates: marker area grows with the square of the radius.
            plt.scatter(c_coordinates[:, 0],
                        c_coordinates[:, 1],
                        c=colors,
                        s=(c_radius * 60) ** 2,
                        alpha=0.5)

            # Voters: small, faint gray points.
            plt.scatter(v_coordinates[:, 0],
                        v_coordinates[:, 1],
                        s=20,
                        c='gray',
                        alpha=0.1)

            plt.axis('off')
            plt.savefig(f'{out_dir}/{method}_{base_distance}_{suffix}.png',
                        dpi=300, bbox_inches='tight')
            plt.clf()

        print(f"{base_distance}: mean fraction of ones in P over "
              f"{num_tests} profiles = {sat / num_tests:.4f}")

